import pybullet as p
import pybullet_data
import pybullet_utils.bullet_client as bc
import time
import random
import numpy as np
import os
from src.bulletRobots import robotsDict
import gym
import utils.misc

# Small tolerance constant; it is not referenced in the visible part of this file.
EPS = 1e-3

class BulletEnv(gym.Env):
    def __init__(self,
                 seed=1337,
                 gui=False,
                 init_set=None,
                 term_set=None,
                 term_sampler=None,
                 option_guide=None,
                 region_switch_point=None,
                 demo_mode=False,
                 robot_config=None,
                 env_path=None,
                 forked=True,
                 envid='default',
                 max_ep_len=100):
        """Configure the environment and, unless ``forked``, start PyBullet.

        :param seed: Seed forwarded to :meth:`set_seed`.
        :param gui: Use ``p.GUI`` when true, ``p.DIRECT`` otherwise.
        :param init_set: Object whose ``sample()`` yields start configurations.
        :param term_set: Callable evaluated on poses by the reward functions
            (``1`` is treated as goal reached, ``0`` as not reached).
        :param term_sampler: Object whose ``sample()`` is plotted in GUI mode.
        :param option_guide: Optional object exposing waypoints as ``.data``.
        :param region_switch_point: Index into ``option_guide.data``.
        :param demo_mode: When true, the environment mesh is not loaded.
        :param robot_config: Mapping read for ``'name'``, ``'ulimits'``,
            ``'llimits'`` and ``'model_path'``.
        :param env_path: Path of the mesh loaded by :meth:`load_stl`.
        :param forked: When true, PyBullet start-up is deferred to the first
            :meth:`reset` call.
        :param envid: Identifier used in debug output and log file names.
        :param max_ep_len: Step count at which an episode is ended.
        :raises ValueError: If ``init_set`` or ``term_set`` is ``None``.
        """
        # Validate before any other work: a non-forked start-up samples from
        # init_set while loading the robot, so checking afterwards would
        # surface as an unrelated AttributeError instead of this message.
        if init_set is None or term_set is None:
            raise ValueError("init_set and term_set must not be None")

        self.set_seed(seed=seed)
        self.mode = p.GUI if gui else p.DIRECT

        self.forked = forked
        self.env_id = envid
        self.time_steps = 2400  # stepSimulation calls per simulated action
        self.interval = 240000.
        self.env_mode = 'train'
        self.reward = 0
        self.done = False
        self.action_count = 0
        self.bodies = {}
        self.max_ep_len = max_ep_len
        self.init_set = init_set
        self.term_set = term_set
        self.term_sampler = term_sampler
        self.local_target = None
        self.total_timesteps = 0
        self.motion_clip_factor = 1
        self.robot_config = robot_config
        self.demo_mode = demo_mode
        self.episode_step_ctr = 0
        self.env_path = env_path
        self.last_pose = None
        self.log_arr = []
        # gui_off=True: no physics client exists yet, so nothing can be drawn.
        self.update_targets(option_guide, region_switch_point, gui_off=True)
        self.info = utils.misc.create_env_info_dict()
        self.ulimits = np.asarray(self.robot_config['ulimits'])
        self.llimits = np.asarray(self.robot_config['llimits'])

        self.action_space = robotsDict[self.robot_config['name']].action_space

        # Observations are the pose followed by a guide-offset vector, so the
        # bounds concatenate the pose limits with the action bounds.
        low = self.llimits.tolist() + self.action_space.low.tolist()
        high = self.ulimits.tolist() + self.action_space.high.tolist()
        self.observation_space = gym.spaces.Box(low=np.array(low),
                                                high=np.array(high))

        if not self.forked:
            self.start_pybullet()

    def set_seed(self, seed):
        np.random.seed(seed)
        random.seed(seed)

    def streamDebug(self, msg, **kwargs):

        print(__name__+'-'+str(self.env_id)+":"+msg, **kwargs)


    def set_views(self):
        """Disable shadows, mouse picking and the GUI flag, then set the debug camera.

        The camera is placed at distance 5.0 with yaw 270 and pitch -90.4,
        targeting the origin.
        """
        disabled_flags = (
            p.COV_ENABLE_SHADOWS,
            p.COV_ENABLE_MOUSE_PICKING,
            self._p.COV_ENABLE_GUI,
        )
        for flag in disabled_flags:
            self._p.configureDebugVisualizer(flag, 0)

        camera = dict(cameraDistance=5.0,
                      cameraYaw=270.,
                      cameraPitch=-90.4,
                      cameraTargetPosition=[0.0, 0.0, 0.0],
                      physicsClientId=self.phys)
        self._p.resetDebugVisualizerCamera(**camera)


    def start_pybullet(self):
        """Connect to PyBullet and populate the scene.

        Creates the client, configures the view and gravity, loads the
        environment mesh (skipped in demo mode) and the robot. In GUI mode
        with an option guide, ten terminal samples and ten initial samples
        are drawn, followed by the guide itself.
        """
        client = bc.BulletClient(connection_mode=self.mode)
        self._p = client
        self.phys = client._client
        self.set_views()
        client.setAdditionalSearchPath(pybullet_data.getDataPath())
        client.setGravity(gravX=0, gravY=0, gravZ=-10, physicsClientId=self.phys)

        if not self.demo_mode:
            self.bodies['env'] = self.load_stl(os.path.abspath(self.env_path))

        self.load_robot()

        show_overlays = self.mode == p.GUI and self.option_guide is not None
        if show_overlays:
            for _ in range(10):
                self.plot_angles(self.term_sampler.sample(), color=[1, 0, 0])
                self.plot_angles(self.init_set.sample(), color=[1, 1, 0])
            self.plot_guide(self.option_guide)

    def plot_guide(self,guide):
        if guide is not None:
            for point in guide.data:
                self._p.addUserDebugLine([point[0],point[1],0.0],[point[0],point[1],0.5],[0,0,1])


    def debugTools(self):

        startPos = self.init_set.sample()
        endPos = self.term_set.pos

        self._p.addUserDebugLine([startPos[0],0.,0.],[startPos[0]+0.5,0.,0.],[1,0,0])
        self._p.addUserDebugLine([0.,startPos[1],0.],[0.,startPos[1]+0.5,0.],[0,1,0])
        self._p.addUserDebugLine([0.,0.,0.0],[0.,0.,0.5],[0,0,1])

        self._p.addUserDebugLine([endPos[0],0.,0.],[endPos[0]+0.5,0.,0.],[1,0,0])
        self._p.addUserDebugLine([0.,endPos[1],0.],[0.,endPos[1]+0.5,0.],[0,1,0])
        self._p.addUserDebugLine([0.,0.,0.0],[0.,0.,0.5],[0,0,1])

    def load_robot(self):
        """Instantiate the configured robot and place it at a sampled start pose.

        The robot class is looked up in ``robotsDict`` by
        ``robot_config['name']``; its body id is stored in
        ``self.bodies['robot']``.

        :raises ValueError: If the configured robot name is not in ``robotsDict``.
        """
        self.streamDebug("Load robot...")
        self.start_pose_robot = self.get_active_DOF_start_config()

        # The previous `is None` test on the constructed robot could never
        # fire; an unknown name fails at the lookup, so check it there.
        try:
            robot_cls = robotsDict[self.robot_config['name']]
        except KeyError:
            raise ValueError(f"robot must be one of {list(robotsDict.keys())}") from None

        # <parent of this file's directory>/data, which is what the former
        # string-concatenated path normalised to.
        data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'data'))
        self.robot = robot_cls(bullet_client=self._p,
                               physclient=self.phys,
                               urdf_path=os.path.join(data_dir, self.robot_config['model_path']),
                               start_pos=self.start_pose_robot[0]  # Set up the robot base
                               )

        self.bodies['robot'] = self.robot.bodyId

        self.setPose(self.start_pose_robot)
        self.streamDebug("Loaded robot")
    
    def collision_at_pose(self, pose):
        # original_pose = self.getRobotCurState()
        original_pose = self.robot.getPose()
        self.robot.setPose(pose)
        pt, collision = self.check_body_collision(body1=self.bodies['robot'],
                                     body2=self.bodies['env'])
        self.robot.setPose(original_pose)
        return collision


    def check_body_collision(self, body1, body2, linkIndex1=-1, linkIndex2=-1):

        if isinstance(linkIndex1, list):
            closestPts = []
            for index in linkIndex1:
                idx_closest = self._p.getClosestPoints(bodyA=body1,
                                        bodyB=body2,
                                        linkIndexA=index,
                                        linkIndexB=linkIndex2,
                                        distance=0.08)
                closestPts.extend(idx_closest)
        else:
            closestPts = self._p.getClosestPoints(bodyA=body1,
                                              bodyB=body2,
                                              linkIndexA=linkIndex1,
                                              linkIndexB=linkIndex2,
                                              distance=0.08)
        minDistance = 0.01
        if len(closestPts) == 0:
            return minDistance, False
        # else:
        #     for pt in closestPts:
        #         if pt[8] < minDistance:
        #             minDistance = pt[8]
        #         if pt[8] < 0:
        #             return  minDistance, True
        return  minDistance, True

    def load_stl(self, path):
        """Load the ground plane and a mesh body built from ``path``.

        The plane URDF's body id is stored in ``self.bodies['plane']``. The
        mesh file is used for both the collision shape and the visual shape
        of a new multibody placed at ``(0, 0, -0.25)``.

        :param path: Mesh file passed to PyBullet as ``fileName``.
        :return: Body id returned by ``createMultiBody``.
        """
        self.bodies['plane'] = self._p.loadURDF(fileName='plane.urdf',
                                                basePosition=[0, 0, 0],
                                                baseOrientation=self._p.getQuaternionFromEuler([0, 0, 0]),
                                                physicsClientId=self.phys)
        col_shape_id = self._p.createCollisionShape(
            shapeType=p.GEOM_MESH,
            fileName=path,
            flags=p.URDF_INITIALIZE_SAT_FEATURES | p.GEOM_FORCE_CONCAVE_TRIMESH
        )
        # baseVisualShapeIndex expects an id from createVisualShape; the
        # earlier code passed a second *collision* shape id here.
        viz_shape_id = self._p.createVisualShape(
            shapeType=p.GEOM_MESH,
            fileName=path,
        )

        body_id = self._p.createMultiBody(
            baseVisualShapeIndex=viz_shape_id,
            baseCollisionShapeIndex=col_shape_id,
            basePosition=(0, 0, -0.25),
            baseOrientation=(0, 0, 0, 1),
        )

        return body_id

    def set_sampler_and_eval_func(self,sampler,eval_func,term_sampler):
        """Replace the start sampler, the termination evaluator and the terminal sampler."""
        self.init_set, self.term_set, self.term_sampler = sampler, eval_func, term_sampler
    

    def get_start_config(self):
        
        # if self.option_guide is not None:
        #     # Zero theta might not work in tight spaces. 
        #     # Try to rotate the robot base into a random position
        #     if len(self.option_guide.data)>1:
        #         theta = 3.14*2 * np.random.random() - 3.14
        #         start_pose = np.append(self.option_guide.data[0],[theta])
        #     else:
        #         start_pose = self.init_set.sample()    
        # else:
        #     start_pose = self.init_set.sample()
        
        start_pose = self.init_set.sample()
        # start_pos = list(start_pose[:2])+[0.0]

        # start_orn = self._p.getQuaternionFromEuler([0.,0., start_pose[2]])
        # return list(start_pos)+list(start_orn)
        return start_pose

    def get_active_DOF_start_config(self):
        
        # This assumes the first three are base translation and rotation X Y theta followed by the other joints
        # Transform rotation from euler to quat and use the base translation as-is
        """
        Get the active DOF values from the init sampler and generate a base position + joint values 
        representation that is propagated everywhere in the environment

        :return: Two value list that contains the base X-Y-Theta in the first position and the joint values in the second position
        :rtype: list
        """

        start_config = self.init_set.sample()
        # base_pose = start_pose[:3]
        # start_trans = list(base_pose[:2])+[0.0]
        # start_orn = self._p.getQuaternionFromEuler([0.,0., base_pose[2]])
        # base = (list(start_trans),list(start_orn))
        # start_config = (base, start_pose[3:])
        return start_config

    def get_mp_obs(self, pose):
        """Return, as a list, the offset from ``pose`` to the guide point chosen by ``get_closest_point_on_guide``."""
        guide_point = self.get_closest_point_on_guide(pose)
        offset = self.vector_to_point(pose, guide_point)
        return list(offset)

    def get_closest_point_on_guide(self,current_config, use_target=True):
        if self.option_guide is not None:
            if use_target is True:
                closest_point = self.option_guide.data[self.target_index]
            else:
                mindist = float('inf')
                closest_point = None
                for i in range(len(self.option_guide.data)):
                    dist = np.linalg.norm(np.array(self.vector_to_point(current_config, self.option_guide.data[i])))
                    if  dist < mindist:
                        mindist = dist
                        closest_point = self.option_guide.data[i]
        else:
            closest_point = [0,0, 0]
        return closest_point

    def vector_to_point(self,current_config,point):
        """Return ``point - current_config`` as a NumPy array."""
        origin = np.asarray(current_config)
        destination = np.asarray(point)
        return destination - origin

    def linear_distance(self,current,point):
        """Euclidean distance between the first two components of ``current`` and ``point``."""
        planar_offset = np.asarray(current[:2]) - np.asarray(point[:2])
        return np.linalg.norm(planar_offset)
    
    def reach_interface_centroid(self, current_config):
        if self.option_guide is not None:
            if self.target_index <= self.region_switch_point:
                interface_centroid_dist = 1/(np.linalg.norm(np.array(self.vector_to_point(current_config, self.option_guide.data[self.region_switch_point]))))
            else:
                interface_centroid_dist = np.linalg.norm(np.array(self.vector_to_point(current_config, self.option_guide.data[self.region_switch_point])))
            return interface_centroid_dist
        else:
            return 0

    def get_additional_reward(self,current_config):    
        closest_point = np.array(self.get_closest_point_on_guide(current_config))
        last_point = np.array(self.option_guide.data[-1])
        distance = np.linalg.norm(self.vector_to_point(closest_point,last_point))
        distance2 = np.linalg.norm(np.array(self.vector_to_point(current_config,closest_point)))

        waypoints_to_end = 0.2*(len(self.option_guide.data) - self.target_index)
        return -waypoints_to_end-distance2
        # return -distance2


    def simulate_reward(self,pose):
        """Return ``(1000, 1)`` when ``self.term_set(pose)`` equals 1, otherwise ``(-1, 0)``."""
        goal_reached = self.term_set(pose) == 1
        return (1000, 1) if goal_reached else (-1, 0)
    
    def check_dof_threshold(self,current,target): 
        """True when every component from index 2 onward differs by less than 0.2 between ``current`` and ``target``."""
        deviation = np.abs(np.asarray(current[2:]) - np.asarray(target[2:]))
        return bool((deviation < 0.2).all())


    def compute_reward(self, pose, origpose):
        
        reward = 0

        reached = self.term_set(pose)
        # self.streamDebug("evaluation {}".format(reached))
        last_point = np.array(self.option_guide.data[-1])
        lin_distance = self.linear_distance(pose,last_point)

        if reached == 1: 
            # if lin_distance < 0.4:
            #     self.plot(pose,[1,0,1])
            #     reward = +2000
            #     self.streamDebug("Reached goal. Current pose: {} at step:{} with rewards: {}".format(pose, self.episode_step_ctr,self.info["reward"]+reward))
            # else:
            #     reached = 0
            reward += 1000
            self.streamDebug("Reached goal. Current pose: {} at step:{} with rewards: {}".format(pose, self.episode_step_ctr,self.info["reward"]+reward))
        elif reached == 0:
        # else:
            endpose_reward = self.get_additional_reward(pose)
            reward = -1 + endpose_reward
            ######################################################################
        else:
            reward = -1000
            reached = 2

        if np.linalg.norm(np.array(self.vector_to_point(pose,self.get_closest_point_on_guide(pose)))) < 0.3: 
            if self.target_index < len(self.option_guide.data) - 1:
                self.target_index += 1
            self.local_target = self.option_guide.data[self.target_index]
        
        # pt, collision = self.check_body_collision(body1=self.bodies['robot'],
        #                              body2=self.bodies['env'],
        #                              linkIndex1=self.robot.defaultLink)
        # if collision or reached == -2:
        #     # reward -= 50
        #     reward -= 5
        #     return pt, reward, reached

        return reward, reached

    def get_normalized_vector(self,pose): 
        pose = np.asarray(pose)
        normlaized_pose = (pose - self.ulimits) / (self.ulimits - self.llimits)
        return normlaized_pose


    def apply_action(self,pose,action):
        """Placeholder hook: performs no work and returns ``None``."""
        return None


    def sample_action(self):
        '''
        Sample action by guiding it along the option guide
        '''
        pose = self.robot.getPose()
        cur_closest = self.get_closest_point_on_guide(pose, use_target=True)
        start_dist = np.linalg.norm(np.array(self.vector_to_point(pose, cur_closest)))

        start_lin_dist = self.linear_distance(pose,cur_closest)

        start_time = time.time()

        # while True and time.time()  - start_time < 2:
            # action = self.action_space.sample()
        action = (np.asarray(cur_closest) - np.asarray(pose)).tolist()
            # targetPos = []
            # action_list = []
            # for jstate, act in zip(pose, action):
            #     action_list.append(act)
            #     # newPos = jstate[0]+act * self.motion_clip_factor
            #     newPos = jstate+act
            #     targetPos.append(newPos)

            # end_dist = np.linalg.norm(np.array(self.vector_to_point(self.get_normalized_vector(targetPos), self.get_normalized_vector(cur_closest))))
            # end_lin_dist = self.linear_distance(targetPos,cur_closest)
            # if end_lin_dist < start_lin_dist:
            # # if start_dist > end_dist:
            #     break
        
        return action

    def step_simulate(self):
        for _ in range(self.time_steps):
            self._p.stepSimulation()



    def setPose(self,pose):
        if not self.collision_at_pose(pose):
            self.robot.setPose(pose)
            return False
        return True

    def step(self, action):
        '''
        Only step if done state is not reached
        Log actions as additional information
        '''
        action = np.clip(action,a_min = self.action_space.low, a_max = self.action_space.high)
        temp = np.random.random_sample(size=self.action_space.shape[0])
        action += (0.2*temp - 0.1)
        if self.done:
            return self.last_pose, 0, self.done, self.info
        self.done = False
        self.total_timesteps+=1
        self.action_count += 1 # Track number of actions taken in each episode

        original_pose = self.robot.getPose() # Only use the 
        targetPos = []
        cur = []
        action_list = []
        for jstate, act in zip(original_pose, action):
            cur.append(jstate)
            action_list.append(act)
            newPos = jstate+act * self.motion_clip_factor
            targetPos.append(newPos)

        targetPos = np.clip(np.asarray(targetPos),a_min = self.llimits, a_max = self.ulimits).tolist()
        self.info['action'] = action_list
        cs = self.setPose(targetPos)
        # self.robot.setPose(targetPos)
        robotpose = self.robot.getPose()
    
        # print(twoDimPose)

        # if self.total_timesteps > 10000:
        #     self.info['achieved_pos'].append(twoDimPose[:2]+[0.])
        #     self.info['desired_pos'].append(targetPos[:2]+[-0.05])

        self.info['pos_difference'] = np.array(targetPos[:2])-np.array(robotpose[:2])
        if self.env_mode == "eval":
            reward, env_state = self.simulate_reward(robotpose)
        else:
            reward, env_state = self.compute_reward(robotpose, original_pose)

        # if cs:
        #     reward -= 50
        # self.streamDebug("step reward: {}".format(reward))
        
        self.info['cumreward'] += reward
        self.info['reward'] += reward

        self.episode_step_ctr += 1

        self.info['reward'] += reward

        if self.episode_step_ctr == self.max_ep_len or env_state==1 or env_state==2:
            self.info['episode_reward'].append([self.episode_step_ctr, self.info['reward']])
            self.done = True
            self.info['total_actions'] += self.action_count
            self.info['total_dones'] += 1
            self.info["robot_pose"].append(robotpose)
            self.info["terminal"] = True
            if env_state == 1:
                # Only count up resets where the actions led to the goal
                self.info['action_count'] += self.action_count
                self.info['done_count'] += 1
            else:
                self.plot(robotpose,[0,0,0])

        # mp_obs = self.get_mp_obs(twoDimPose)
        # twoDimPose.append(dist_to_collision)
        # self.last_pose = twoDimPose
        # twoDimPose.extend(mp_obs)

        # return twoDimPose, reward, self.done, self.info

        observation = self.create_observation_from_pose(robotpose)
        # observation.append(dist_to_collision)
        self.last_pose = observation

        # self.streamDebug("Observation len: {}".format(len(observation)))



        return observation, reward, self.done, self.info

    def step2(self, action):
        '''
        Only step if done state is not reached
        Log actions as additional information
        '''
        self.done = False
        self.total_timesteps+=1
        self.action_count += 1 # Track number of actions taken in each episode
        joints_cur_state = self.robot.getJointStates()

        targetPos = []
        cur = []
        action_list = []
        for jstate, act in zip(joints_cur_state, action):
            cur.append(jstate[0])
            action_list.append(act)
            newPos = jstate[0]+act * self.motion_clip_factor

            targetPos.append(newPos)
        self.info['action'] = action_list[:2]
        self.robot.setJoints(targetPos)

        for _ in range(self.time_steps):
            self._p.stepSimulation()



    def plot(self,point,color = [1,0,0]):
        self._p.addUserDebugLine([point[0],point[1],0.0],[point[0],point[1],0.5],color)

    def plot_angles(self, point, color):

        self.plot(point,color)

        x = np.cos(point[2])*0.2+point[0]
        y = np.sin(point[2])*0.2+point[1]
        negx = -np.cos(point[2])*0.2+point[0]
        negy = -np.sin(point[2])*0.2+point[1]
        self._p.addUserDebugLine([x,y,0.05],[negx,negy,0.05],color)


    def logger(self, action):
        curpose = self.robot.getPose()
        self.log_arr.append([action[:2], curpose[0][:2]])

    def make_log_dir(self):
        dirname = []
        if getattr(self.term_set.eval['src'],'id',None) is not None:
            dirname.append(getattr(self.term_set.eval['src'],'id'))
        else:
            dirname.append('point')
        if getattr(self.term_set.eval['dest'],'id',None) is not None:
            dirname.append(getattr(self.term_set.eval['dest'],'id'))
        else:
            dirname.append('point')
        dirname = '_'.join(map(str, dirname))

        return dirname

    def save_log(self):
        """Write ``self.log_arr`` to ``./logs/action_logs/<log dir>/<env id>.npy``.

        The directory is created if needed.

        :return: The directory the file was written to.
        """
        run_name = self.make_log_dir()
        logdir = os.path.join('./logs/action_logs/{}'.format(run_name))
        os.makedirs(logdir, exist_ok=True)
        target_file = "{}/{}.npy".format(logdir,self.env_id)
        np.save(target_file, np.array(self.log_arr))
        return logdir


    def safeRobotSet(self, pose=None):
        """Place the robot at ``pose``, resampling until placement succeeds.

        A missing ``pose`` is drawn from ``get_start_config``. While
        ``setFullRobotState`` reports failure a new start configuration is
        drawn; the accepted pose is then applied once more and returned.
        There is no retry limit.
        """
        candidate = self.get_start_config() if pose is None else pose

        while not self.setFullRobotState(candidate, sim=False):
            candidate = self.get_start_config()

        self.setFullRobotState(candidate, sim=False)
        return candidate
   
    def setFullRobotState(self, pose, sim=True):
        if len(pose) == 2:
            pose = pose[1]

        # self.robot.resetFullPose(base_trans, base_rot, joint_pos)
        self.setPose(pose)



        pt, collision = self.check_body_collision(body1=self.bodies['robot'],
                                                  body2=self.bodies['env'],
                                                  linkIndex1=self.robot.defaultLink)
        if collision:
            return False

        if sim:
            for i in range(self.time_steps):
                self._p.stepSimulation()
        
        return True

    def create_observation_from_pose(self,pose):
        mp_obs = self.get_mp_obs(pose)
        return pose + mp_obs
        
    def reset(self):
        # self.streamDebug("Resetting..")
        if self.forked:
            self.start_pybullet()
            self.forked = False

        if self.env_mode == 'train':
            self.done = False
            self.reward = 0
            self.action_count = 0
            self.episode_step_ctr = 0
            self.info['reward'] = 0

            pose = self.get_start_config()
            self.robot_pose = self.safeRobotSet(pose)
            
            self.reset_target_idx()
            self.reset_local_target()

            # twoDimPose = self.get_reduced_2D_pose((self.start_pose_base[:3],self.start_pose_base[3:]))

        else:
            self.robot_pose = self.robot.getPose()
        
        observation = self.create_observation_from_pose(self.robot_pose)
        

        # twoDimPose.append(pt)
        # observation.append(pt)

        # return twoDimPose
        return observation

    def reset_local_target(self):
        if self.option_guide is not None:
                self.local_target = self.option_guide.data[self.target_index]


    def reset_target_idx(self):
        if self.option_guide is None:
            self.target_index = 1
        else:
            self.target_index = int(len(self.option_guide.data) != 1) 
    
    def update_targets(self, option_guide, region_switch_point,gui_off=False):
        """Install a new option guide and switch point, then reset the target index.

        In GUI mode the guide is also drawn, unless ``gui_off`` is set or no
        guide was given.
        """
        self.option_guide = option_guide
        self.region_switch_point = region_switch_point
        self.reset_target_idx()

        should_draw = self.mode == p.GUI and not gui_off and self.option_guide is not None
        if should_draw:
            self.plot_guide(self.option_guide)

    def close(self):
        self._p.disconnect()



class SACBulletEnv(gym.Env):
    def __init__(self,
                 seed=1337,
                 gui=False,
                 init_config=None,
                 goal_config=None,
                 robot_config=None,
                 env_path=None,
                 forked=True,
                 envid='default',
                 max_ep_len=100,
                 problem_number = -1):
        """Configure the environment; optionally start PyBullet right away.

        :param seed: seed handed to ``set_seed``
        :param gui: truthy selects ``p.GUI``, otherwise ``p.DIRECT``
        :param init_config: start configuration returned by ``get_start_config``
        :param goal_config: goal configuration used by ``compute_reward``
        :param robot_config: mapping; ``'name'``, ``'ulimits'`` and
            ``'llimits'`` are read here
        :param env_path: path of the environment mesh loaded in ``start_pybullet``
        :param forked: when true, PyBullet start-up is deferred to ``reset()``
        :param envid: identifier used in debug output and log file names
        :param max_ep_len: number of steps after which an episode ends
        :param problem_number: used to name the log directory
        """
        self.set_seed(seed=seed)
        self.mode = p.GUI if gui else p.DIRECT

        # Construction arguments.
        self.forked = forked
        self.env_id = envid
        self.problem_number = problem_number
        self.max_ep_len = max_ep_len
        self.init_config = init_config
        self.goal_config = goal_config
        self.robot_config = robot_config
        self.env_path = env_path

        # Fixed settings.
        self.time_steps = 2400
        self.interval = 240000.
        self.env_mode='train'
        self.motion_clip_factor = 1

        # Mutable episode / bookkeeping state.
        self.reward = 0
        self.done = False
        self.action_count = 0
        self.episode_step_ctr = 0
        self.total_timesteps=0
        self.bodies = {}
        self.local_target = None
        self.last_pose = None
        self.log_arr = []
        self.info = utils.misc.create_env_info_dict()

        # Joint limits double as the bounds of the observation space.
        self.ulimits = np.asarray(self.robot_config['ulimits'])
        self.llimits = np.asarray(self.robot_config['llimits'])

        self.action_space = robotsDict[self.robot_config['name']].action_space
        self.observation_space = gym.spaces.Box(low=np.array(self.llimits.tolist()),
                                                high=np.array(self.ulimits.tolist()))

        if not self.forked:
            self.start_pybullet()

    def set_seed(self, seed):
        """Seed NumPy's global RNG and the stdlib ``random`` module with ``seed``."""
        for seed_fn in (np.random.seed, random.seed):
            seed_fn(seed)

    def streamDebug(self, msg, **kwargs):
        """Print ``msg`` prefixed with ``<module name>-<env id>:``.

        Any keyword arguments are forwarded to ``print``.
        """
        prefix = __name__ + '-' + str(self.env_id)
        print(prefix + ":" + msg, **kwargs)


    def set_views(self):
        """Configure the debug visualizer and place the camera.

        Sets the shadows, mouse-picking and GUI visualizer flags to 0, then
        positions the debug camera 5 units from the origin with yaw 270 and
        pitch -90.4 (presumably a top-down view - confirm in the GUI).
        """
        client = self._p
        disabled_flags = (p.COV_ENABLE_SHADOWS,
                          p.COV_ENABLE_MOUSE_PICKING,
                          client.COV_ENABLE_GUI)
        for flag in disabled_flags:
            client.configureDebugVisualizer(flag, 0)

        camera = dict(cameraDistance=5.0,
                      cameraYaw=270.,
                      cameraPitch=-90.4,
                      cameraTargetPosition=[0.0, 0.0, 0.0],
                      physicsClientId=self.phys)
        client.resetDebugVisualizerCamera(**camera)


    def start_pybullet(self):
        """Connect to PyBullet and populate the scene.

        Creates the bullet client in ``self.mode``, records its client id in
        ``self.phys``, configures the views, sets gravity to -10 along z,
        loads the environment mesh from ``self.env_path`` into
        ``self.bodies['env']`` and finally loads the robot.
        """
        client = bc.BulletClient(connection_mode=self.mode)
        self._p = client
        self.phys = client._client
        self.set_views()

        client.setAdditionalSearchPath(pybullet_data.getDataPath())
        client.setGravity(gravX=0, gravY=0, gravZ=-10, physicsClientId=self.phys)

        mesh_path = os.path.abspath(self.env_path)
        self.bodies['env'] = self.load_stl(mesh_path)

        self.load_robot()


    def load_robot(self):
        """Instantiate the configured robot and move it to its start pose.

        The robot class is looked up in ``robotsDict`` by
        ``self.robot_config['name']``; its URDF is resolved relative to the
        ``data`` directory two levels above this file. The first element of
        the start configuration is passed to the robot as ``start_pos`` and
        the full start configuration is then applied through ``setPose``.

        :raises ValueError: if the configured robot name is not in ``robotsDict``
        """
        self.streamDebug("Load robot...")
        self.start_pose_robot = self.get_active_DOF_start_config()

        # Validate the name up front: the previous ``self.robot is None`` check
        # could never fire because a failed lookup raises before assignment.
        robot_name = self.robot_config['name']
        if robot_name not in robotsDict:
            raise ValueError(f"robot must be one of {list(robotsDict.keys())}")

        # <parent of this file's directory>/data - the same directory the old
        # ``abspath(__file__ + '../../../data')`` expression normalised to.
        package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        data_dir = os.path.join(package_root, 'data')

        self.robot = robotsDict[robot_name](bullet_client=self._p,
                                            physclient=self.phys,
                                            urdf_path=os.path.join(data_dir, self.robot_config['model_path']),
                                            start_pos=self.start_pose_robot[0]  # Set up the robot base
                                            )

        self.bodies['robot'] = self.robot.bodyId

        self.setPose(self.start_pose_robot)
        self.streamDebug("Loaded robot")
    
    def collision_at_pose(self, pose):
        """Report whether the robot would collide with the environment at ``pose``.

        The robot is temporarily moved to ``pose`` for the query and is always
        moved back to the pose it had on entry, even if the collision query
        raises.

        :param pose: pose passed to ``self.robot.setPose``
        :return: the collision flag from ``check_body_collision``
        """
        original_pose = self.robot.getPose()
        self.robot.setPose(pose)
        try:
            _, collision = self.check_body_collision(body1=self.bodies['robot'],
                                                     body2=self.bodies['env'],
                                                     linkIndex1=self.robot.defaultLink)
        finally:
            # Never leave the robot parked at the probe pose.
            self.robot.setPose(original_pose)
        return collision


    def check_body_collision(self, body1, body2, linkIndex1=-1, linkIndex2=-1, max_distance=0.5):
        """Query the closest points between two bodies and test for penetration.

        :param body1: id of the first body
        :param body2: id of the second body
        :param linkIndex1: a single link index of ``body1``, or a list/tuple of
            link indices that are all queried
        :param linkIndex2: link index of ``body2``
        :param max_distance: query radius handed to ``getClosestPoints``; it is
            also the distance reported when no closer point is found
        :return: ``(min_distance, collision)`` where ``min_distance`` is the
            smallest distance over all returned points (capped at
            ``max_distance``) and ``collision`` is True when that distance is
            negative
        """
        if isinstance(linkIndex1, (list, tuple)):
            link_indices = linkIndex1
        else:
            link_indices = [linkIndex1]

        closest_pts = []
        for index in link_indices:
            closest_pts.extend(self._p.getClosestPoints(bodyA=body1,
                                                        bodyB=body2,
                                                        linkIndexA=index,
                                                        linkIndexB=linkIndex2,
                                                        distance=max_distance))

        # Element 8 of each point is the separation distance (PyBullet's
        # contactDistance); a negative value means the bodies penetrate.
        # Taking the minimum over every point makes the reported distance
        # independent of the order in which points are returned.
        min_distance = min([max_distance] + [pt[8] for pt in closest_pts])
        return min_distance, min_distance < 0

    def load_stl(self, path):
        """Load the ground plane and the environment mesh at ``path``.

        The plane body id is stored in ``self.bodies['plane']``. The mesh is
        loaded twice: as a concave triangle-mesh collision shape and as a
        visual shape, which are combined into one multibody placed at
        ``(0, 0, -0.25)``.

        :param path: file name of the mesh
        :return: body id of the created environment multibody
        """
        client = self._p
        self.bodies['plane'] = client.loadURDF(fileName='plane.urdf',
                                               basePosition=[0, 0, 0],
                                               baseOrientation=client.getQuaternionFromEuler([0, 0, 0]),
                                               physicsClientId=self.phys)

        col_shape_id = client.createCollisionShape(
            shapeType=client.GEOM_MESH,
            fileName=path,
            flags=client.URDF_INITIALIZE_SAT_FEATURES | client.GEOM_FORCE_CONCAVE_TRIMESH,
        )
        # baseVisualShapeIndex expects an id returned by createVisualShape;
        # a second collision shape's id (as used before) is not a valid
        # visual shape id.
        viz_shape_id = client.createVisualShape(
            shapeType=client.GEOM_MESH,
            fileName=path,
        )

        body_id = client.createMultiBody(
            baseVisualShapeIndex=viz_shape_id,
            baseCollisionShapeIndex=col_shape_id,
            basePosition=(0, 0, -0.25),
            baseOrientation=(0, 0, 0, 1),
        )

        return body_id
    

    def get_start_config(self):
        """Return the configured start pose, ``self.init_config``, unchanged."""
        return self.init_config

    def get_active_DOF_start_config(self):
        """
        Return the start configuration used when loading the robot.

        The value is ``self.init_config`` passed through as-is; no base/joint
        split or Euler-to-quaternion conversion is performed here.

        :return: ``self.init_config``
        """
        return self.init_config


    def simulate_reward(self,goal):
        """Map the goal flag to a ``(reward, env_state)`` pair.

        :param goal: truthy when the goal has been reached
        :return: ``(1000, 1)`` when ``goal`` is truthy, otherwise ``(-1, 0)``
        """
        return (1000, 1) if goal else (-1, 0)
    
    def check_dof_threshold(self,current,target):
        """Check whether ``current`` is within 0.2 of ``target`` on every compared DOF.

        Only the entries from index 2 onwards are compared; the first two
        entries of both vectors are ignored.

        :return: True when every compared entry differs by less than 0.2
        """
        gap = np.abs(np.asarray(current[2:]) - np.asarray(target[2:]))
        return bool((gap < 0.2).all())


    def compute_reward(self, pose, origpose):
        """Return ``(reward, env_state)`` for ``pose`` relative to the goal.

        The goal test is ``check_dof_threshold(pose, self.goal_config)`` and
        its result is mapped by ``simulate_reward``. ``origpose`` is accepted
        but not used.
        """
        return self.simulate_reward(self.check_dof_threshold(pose, self.goal_config))

    def get_normalized_vector(self,pose):
        """Min-max normalise ``pose`` against the joint limits.

        Each entry is mapped so that ``self.llimits`` becomes 0 and
        ``self.ulimits`` becomes 1. The offset is taken from the lower limit;
        subtracting the upper limit (as was done before) produced values in
        [-1, 0] instead.

        :param pose: sequence with one value per limit entry
        :return: numpy array ``(pose - llimits) / (ulimits - llimits)``
        """
        pose = np.asarray(pose)
        return (pose - self.llimits) / (self.ulimits - self.llimits)

    def step_simulate(self):
        """Advance the physics simulation by ``self.time_steps`` steps."""
        n_substeps = self.time_steps
        for _substep in range(n_substeps):
            self._p.stepSimulation()

    def setPose(self,pose):
        """Move the robot to ``pose`` unless that pose is in collision.

        :return: True when ``pose`` collides (robot left where it was),
            False when the pose was applied
        """
        if self.collision_at_pose(pose):
            return True
        self.robot.setPose(pose)
        return False

    def step(self, action):
        """Apply one noisy action and return the usual gym 4-tuple.

        Nothing is simulated once the episode is done: the last observation
        is returned again with zero reward. Otherwise the action is clipped
        to the action space, uniform noise in [-0.1, 0.1) is added per
        dimension, the resulting joint target is clipped to the joint limits
        and applied through ``setPose`` (which leaves the robot in place when
        the target collides).

        The episode ends when ``max_ep_len`` steps were taken or the reward
        function reports state 1 or 2. The applied (noisy) action, the
        position difference and the reward counters are logged in
        ``self.info``.

        :param action: one value per actuated DOF
        :return: ``(robot_pose, reward, done, info)``
        """
        if self.done:
            return self.last_pose, 0, self.done, self.info

        action = np.clip(action, a_min=self.action_space.low, a_max=self.action_space.high)
        # The noise is added after clipping, so the applied action may exceed
        # the action-space bounds by up to 0.1.
        noise = np.random.random_sample(size=self.action_space.shape[0])
        action = action + (0.2 * noise - 0.1)

        self.total_timesteps += 1
        self.action_count += 1  # Track number of actions taken in each episode

        original_pose = self.robot.getPose()
        # zip() stops at the shorter of pose and action, as before.
        paired = list(zip(original_pose, action))
        targetPos = [jstate + act * self.motion_clip_factor for jstate, act in paired]
        targetPos = np.clip(np.asarray(targetPos), a_min=self.llimits, a_max=self.ulimits).tolist()
        self.info['action'] = [act for _, act in paired]

        self.setPose(targetPos)
        robotpose = self.robot.getPose()
        # Remember the observation so a step() on a finished episode can
        # return it instead of None.
        self.last_pose = robotpose

        self.info['pos_difference'] = np.array(targetPos[:2]) - np.array(robotpose[:2])
        reward, env_state = self.compute_reward(robotpose, original_pose)

        # Each counter is incremented exactly once per step; the episode
        # reward used to be added twice.
        self.info['cumreward'] += reward
        self.info['reward'] += reward
        self.episode_step_ctr += 1

        if self.episode_step_ctr == self.max_ep_len or env_state == 1 or env_state == 2:
            self.info['episode_reward'].append([self.episode_step_ctr, self.info['reward']])
            self.done = True
            self.info['total_actions'] += self.action_count
            self.info['total_dones'] += 1
            self.info["robot_pose"].append(robotpose)
            self.info["terminal"] = True
            if env_state == 1:
                # Only count up resets where the actions led to the goal
                self.info['action_count'] += self.action_count
                self.info['done_count'] += 1
        return robotpose, reward, self.done, self.info

   
    def logger(self, action):
        """Append ``[action[:2], pose[0][:2]]`` to ``self.log_arr``.

        NOTE(review): this indexes the robot pose as nested (``pose[0][:2]``),
        whereas ``step`` treats the pose as a flat sequence - confirm which
        layout ``robot.getPose()`` returns for the robots used here.
        """
        current = self.robot.getPose()
        entry = [action[:2], current[0][:2]]
        self.log_arr.append(entry)

    def make_log_dir(self):
        """Return the log directory name for this problem, ``logs_<problem_number>``.

        Despite the name, no directory is created here.
        """
        return 'logs_{}'.format(self.problem_number)

    def save_log(self):
        """Write ``self.log_arr`` to ``./logs/action_logs/<log dir>/<env id>.npy``.

        The directory is created if it does not exist yet.

        :return: the directory the log file was written to
        """
        logdir = './logs/action_logs/{}'.format(self.make_log_dir())
        os.makedirs(logdir, exist_ok=True)

        target = "{}/{}.npy".format(logdir, self.env_id)
        np.save(target, np.array(self.log_arr))
        return logdir


    def safeRobotSet(self, pose=None, max_attempts=100):
        """Place the robot at a collision-free start pose and return that pose.

        ``pose`` (or the start configuration when it is None) is tried first;
        after every failed placement a fresh start configuration is requested
        and tried instead. The number of placements is bounded: here the start
        configuration is a fixed value, so an unbounded retry loop would never
        terminate once a placement fails.

        :param pose: pose to try first; defaults to ``get_start_config()``
        :param max_attempts: maximum number of placements to try
        :return: the pose that was accepted by ``setFullRobotState``
        :raises RuntimeError: if no placement succeeded within ``max_attempts``
        """
        if pose is None:
            pose = self.get_start_config()

        for _ in range(max_attempts):
            # A successful call has already applied the pose, so it is not
            # applied a second time before returning.
            if self.setFullRobotState(pose, sim=False):
                return pose
            pose = self.get_start_config()

        raise RuntimeError(
            "could not place the robot at a collision-free start pose "
            "after {} attempts".format(max_attempts))
   
    def setFullRobotState(self, pose, sim=True):
        """Apply ``pose`` and report whether the robot ends up collision-free.

        A pose of length two is treated as a pair whose second element is the
        configuration to apply.
        NOTE(review): a flat pose with exactly two entries would be unpacked
        the same way - confirm no two-DOF robot is used with this method.

        The pose is applied through ``setPose`` (its return value is ignored),
        then the robot is checked against the environment at whatever pose it
        now has.

        :param pose: configuration, or a two-element pair as described above
        :param sim: when true, run ``step_simulate`` after a collision-free placement
        :return: False if the robot is in collision, True otherwise
        """
        target = pose[1] if len(pose) == 2 else pose
        self.setPose(target)

        _, in_collision = self.check_body_collision(body1=self.bodies['robot'],
                                                    body2=self.bodies['env'],
                                                    linkIndex1=self.robot.defaultLink)
        if in_collision:
            return False

        if sim:
            self.step_simulate()

        return True


    def reset(self):
        """Begin a new episode and return the robot pose as the observation.

        If ``self.forked`` is set, the physics client is started first and
        the flag is cleared. In ``'train'`` mode the per-episode counters are
        zeroed and the robot is placed at the start configuration; in any
        other mode the robot's current pose is used unchanged.

        :return: ``self.robot_pose``
        """
        if self.forked:
            self.start_pybullet()
            self.forked = False

        if self.env_mode != 'train':
            self.robot_pose = self.robot.getPose()
            return self.robot_pose

        # Clear the per-episode bookkeeping.
        self.done = False
        self.reward = 0
        self.action_count = 0
        self.episode_step_ctr = 0
        self.info['reward'] = 0

        start = self.get_start_config()
        self.robot_pose = self.safeRobotSet(start)
        return self.robot_pose

    
    def close(self):
        """Disconnect from the physics server, if a client exists.

        With ``forked=True`` the client is only created by the first
        ``reset()`` (via ``start_pybullet``), so ``_p`` may not exist yet when
        ``close()`` is called; that case is a no-op instead of an
        ``AttributeError``. The client handle is cleared after disconnecting
        so that calling ``close()`` again is also a no-op.
        """
        client = getattr(self, '_p', None)
        if client is None:
            return
        client.disconnect()
        # Drop the handle so a repeated close() does not disconnect twice.
        self._p = None